Add support for TransformerEngine flash attention in WAN #299
base: main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
@cpersson-amd I've been out on PTO for a month. I'll take a closer look at this next week. Meanwhile, can you update your branch with the latest in main? Thanks.
entrpn left a comment:
In general the PR looks good, but I'm still unsure if adding another axis, fsdp_batch, is really necessary. I would prefer not to add it. The other major thing is switching the mesh_axes from data, fsdp, tensor to data, tensor, fsdp.
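For context, a minimal sketch of the mesh-axis ordering being discussed, written in plain JAX. The device reshape and sharding specs below are illustrative assumptions, not the actual MaxDiffusion configuration:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange the available devices along three named axes in the proposed order:
# ("data", "tensor", "fsdp"). The reshape here is only an example layout.
devices = np.array(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devices, axis_names=("data", "tensor", "fsdp"))

# Example sharding: the batch dimension of activations is split across both
# the "data" and "fsdp" axes, while the remaining dimensions are replicated.
activation_sharding = NamedSharding(mesh, PartitionSpec(("data", "fsdp"), None, None))
```

Reordering the axis names changes which physical device dimension each logical axis maps to, which is why the ordering matters for how existing sharding rules resolve.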
@susanbao can you take a quick look at this PR?
@cpersson-amd please review @susanbao's comments above and rebase with main. We tested the PR internally and it looks good. Would you be willing to change the axis fsdp to context? If not, I can make the change after this PR is merged.
Thanks @cpersson-amd, this looks great. Can you run 'ruff check --fix'?
@entrpn Sure, I ran 'ruff check --fix' and had to manually fix some bare except statements. It should be good with the latest commit.
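The bare-except fix mentioned here is the kind of change ruff's E722 rule requires. A small illustrative snippet (not code from this PR, assuming an optional-import pattern):

```python
# Before: a bare `except:` also swallows KeyboardInterrupt and SystemExit.
# After: catch the specific exception that can actually occur.
try:
    import transformer_engine  # noqa: F401
    HAS_TRANSFORMER_ENGINE = True
except ImportError:  # was: `except:`
    HAS_TRANSFORMER_ENGINE = False
```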
@cpersson-amd Please review my PR to fix some of the unit tests. Once they pass, this can be merged. cpersson-amd#1
This PR adds support for TransformerEngine flash attention in WAN. The code has been tested on WAN 2.1 (training and inference) and Flux (training only) using GPUs.
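A minimal sketch of how an alternate attention backend can be gated behind a flag, with plain-JAX scaled dot-product attention as the fallback. `te_flash_attention` and its module path are hypothetical placeholders, not the actual wrapper used in this PR:

```python
import jax.numpy as jnp
from jax.nn import softmax


def reference_attention(q, k, v):
    # q, k, v: [batch, heads, seq_len, head_dim]
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    return jnp.einsum("bhqk,bhkd->bhqd", softmax(logits, axis=-1), v)


def attention(q, k, v, use_te_flash_attention: bool = False):
    if use_te_flash_attention:
        # Hypothetical wrapper around the TransformerEngine attention kernel;
        # the PR wires the real backend into the model's attention layer.
        from te_attention_wrapper import te_flash_attention  # hypothetical module
        return te_flash_attention(q, k, v)
    return reference_attention(q, k, v)
```

Keeping a reference path alongside the fused kernel makes it easy to verify numerics and to run on hardware where the fused kernel is unavailable.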